import os
import random

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# Maps each supported TFDS dataset name to the feature key holding its class
# label. Trailing comments note the image shape and any size/availability
# caveats.
LABEL_COL = {
  'colorectal_histology': 'label',  # (150, 150, 3)
  'deep_weeds': 'label',  # (256, 256, 3), currently unavailable
  'dmlab': 'label',  # (360, 480, 3), 3 GB
  'fashion_mnist': 'label',  # (28, 28, 1)
  'horses_or_humans': 'label',  # (300, 300, 3), currently unavailable
  'mnist': 'label',  # (28, 28, 1)
  'omniglot': 'label',  # (105, 105, 3)
  'places365_small': 'label',  # (256, 256, 3), 30 GB
  'smallnorb': 'label_category',  # (96, 96, 1)
}

# Command-line configuration; absl parses these into FLAGS before main() runs.
flags.DEFINE_enum(
    name='dataset',
    default='mnist',
    enum_values=list(LABEL_COL),  # only datasets with a known label column
    help='Dataset',
)
flags.DEFINE_integer(name='seed', default=2023, help='Random seed')
flags.DEFINE_integer(name='epochs', default=6, help='Epochs')
flags.DEFINE_integer(name='batch_size', default=128, help='Batch size')
flags.DEFINE_float(name='learning_rate', default=0.001, help='Learning rate')
# Multiplicative LR decay factor passed to ReduceLROnPlateau.
flags.DEFINE_float(name='plateau_factor', default=0.5, help='Plateau factor')
# Comma-separated hidden layer sizes (e.g. 256,128); main() int()-converts
# each entry.
flags.DEFINE_list(
    name='hidden_neurons', default=None, help='Number of hidden neurons')
# Output paths for the .npy arrays written at the end of main().
flags.DEFINE_string(
    name='fisher_file', default='fisher.npy',
    help='File to write Fisher matrix')
flags.DEFINE_string(
    name='weight_file', default='weight.npy',
    help='File to write weight matrix')

FLAGS = flags.FLAGS
logger = tf.get_logger()


def empirical_fisher(
    model, ds, weight_index=0) -> 'tuple[np.ndarray, tf.Variable]':
  """Computes the diagonal empirical Fisher for one trainable weight.

  For every example, the per-example sparse cross-entropy gradient with
  respect to ``model.trainable_weights[weight_index]`` is squared
  elementwise. The squares are summed over the whole dataset and divided by
  the number of examples.

  Args:
    model: Keras model whose outputs are logits (the loss uses
      ``from_logits=True``).
    ds: Iterable of ``(x, y)`` batches; ``y`` holds integer class labels.
    weight_index: Index into ``model.trainable_weights`` selecting the
      weight whose Fisher diagonal is computed.

  Returns:
    A ``(fisher, weight)`` tuple. ``fisher`` is a NumPy array with the
    weight's shape. ``weight`` is the selected trainable variable itself
    (not a copy).

  Raises:
    ValueError: If ``ds`` yields no examples.
  """
  # reduction=NONE keeps one loss value per example, so the Jacobian below
  # yields one gradient per example rather than a batch-mean gradient.
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
  )
  weight = model.trainable_weights[weight_index]
  count = 0
  fisher = np.zeros(weight.shape)
  for x, y in ds:
    with tf.GradientTape() as tape:
      example_loss = loss(tf.expand_dims(y, axis=-1), model(x))
    count += int(tf.shape(y)[0])
    # Shape (batch, *weight.shape): per-example gradients.
    per_example_grad = tape.jacobian(example_loss, weight)
    fisher += tf.math.reduce_sum(
        tf.math.square(per_example_grad), axis=0).numpy()
  if count == 0:
    raise ValueError('empirical_fisher: dataset yielded no examples')
  fisher /= count
  return fisher, weight


def _set_seeds(seed: int) -> None:
  """Seeds Python, NumPy and TensorFlow RNGs and enables deterministic ops.

  Args:
    seed: Seed applied to every random number generator.
  """
  # NOTE: setting PYTHONHASHSEED at runtime only affects child processes;
  # this interpreter's string hashing was already seeded at startup.
  os.environ['PYTHONHASHSEED'] = str(seed)
  random.seed(seed)
  np.random.seed(seed)
  tf.random.set_seed(seed)
  tf.keras.utils.set_random_seed(seed)
  tf.config.experimental.enable_op_determinism()


def _normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32` divided by 255."""
  return tf.cast(image, tf.float32) / 255., label


def _load_datasets(name: str, batch_size: int, seed: int):
  """Loads a TFDS dataset and builds normalized, batched input pipelines.

  Args:
    name: TFDS dataset name (a key of ``LABEL_COL``).
    batch_size: Batch size used for both pipelines.
    seed: Seed for TFDS file-level shuffling.

  Returns:
    ``(ds_train, ds_test, ds_info)``. The training pipeline is cached and
    shuffled. The test pipeline is cached but not shuffled.
  """
  split = ['train', 'test']
  if name == 'colorectal_histology':
    # Only a 'train' split is used for this dataset, so hold out its last 20%
    # for evaluation. Uses the TFDS slicing syntax; the legacy
    # `tfds.Split.TRAIN.subsplit(tfds.percent[...])` API was removed.
    split = ['train[:80%]', 'train[80%:]']
  (ds_train, ds_test), ds_info = tfds.load(
      name,
      split=split,
      shuffle_files=True,
      as_supervised=True,
      with_info=True,
      # Seed the file shuffle so the input order is reproducible.
      read_config=tfds.ReadConfig(shuffle_seed=seed),
  )

  ds_train = ds_train.map(_normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.cache()
  # Buffer sized to the whole 'train' split, i.e. a full shuffle.
  ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
  ds_train = ds_train.batch(batch_size)
  ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

  ds_test = ds_test.map(_normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(batch_size)
  ds_test = ds_test.cache()
  ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
  return ds_train, ds_test, ds_info


def _build_model(num_classes: int, hidden_neurons, learning_rate: float):
  """Builds and compiles an MLP classifier that outputs logits.

  Args:
    num_classes: Width of the final (logit) Dense layer.
    hidden_neurons: Sizes of the ReLU hidden layers, in order; may be empty.
    learning_rate: Adam learning rate.

  Returns:
    The compiled ``tf.keras.models.Sequential`` model.
  """
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    *[tf.keras.layers.Dense(num, activation='relu') for num in hidden_neurons],
    tf.keras.layers.Dense(num_classes),
  ])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  return model


def main(argv) -> None:
  """Trains an MLP on --dataset, then saves its empirical Fisher and weights.

  Both arrays are computed for the model's first trainable weight (the first
  Dense layer's kernel). They are written with ``np.save`` to
  --fisher_file and --weight_file.
  """
  del argv  # Unused.
  _set_seeds(FLAGS.seed)

  ds_train, ds_test, ds_info = _load_datasets(
      FLAGS.dataset, FLAGS.batch_size, FLAGS.seed)

  label_col = LABEL_COL[FLAGS.dataset]
  # --hidden_neurons defaults to None; treat that as "no hidden layers"
  # instead of failing when iterating over None.
  hidden_neurons = [int(num) for num in (FLAGS.hidden_neurons or [])]
  model = _build_model(
      ds_info.features[label_col].num_classes,
      hidden_neurons,
      FLAGS.learning_rate,
  )

  callbacks = [
    tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
  ]
  if FLAGS.plateau_factor is not None:
    # append() mutates the list in place; a bare `callbacks + [...]`
    # expression would build a new list and drop the callback.
    callbacks.append(
        tf.keras.callbacks.ReduceLROnPlateau(factor=FLAGS.plateau_factor))
  model.fit(
      ds_train,
      epochs=FLAGS.epochs,
      validation_data=ds_test,
      callbacks=callbacks,
  )
  results = model.evaluate(ds_test, return_dict=True)
  logger.info(results)

  # Default weight_index=0 selects the first Dense kernel; Flatten has no
  # trainable weights.
  fisher, weight = empirical_fisher(model, ds_train)
  np.save(FLAGS.fisher_file, fisher)
  np.save(FLAGS.weight_file, weight)

if __name__ == '__main__':
  # app.run parses the absl command-line flags into FLAGS, then calls main.
  app.run(main)